import os.path
import pickle
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import numpy as np
from PIL import Image


class CifarC(torch.utils.data.Dataset):
    """Corrupted CIFAR-style test set restricted to one severity level.

    The corrupted images are read from ``data_path`` with ``np.load``. All
    severity levels are assumed to be stacked along the first axis, and only
    the ``level``-th contiguous chunk of ``test_set_size`` rows is kept.

    Labels are not read from ``data_path``. They come from the original
    (uncorrupted) pickled test batch named ``test_list`` inside
    ``original_labels_path``. The corrupted rows are therefore assumed to be
    in the same order as the original test set (TODO confirm for the data
    source).

    Args:
        data_path: Path to the array file holding every severity level.
            Presumably shaped ``(levels * test_set_size, H, W, C)`` uint8;
            confirm against the data source.
        transform: Optional callable applied to each ``PIL.Image`` in
            ``__getitem__``.
        level: 1-based severity level selecting which chunk of rows to keep.
        original_labels_path: Directory containing the original pickled test
            batch. Required despite the ``None`` default.

    Raises:
        ValueError: If ``original_labels_path`` is ``None``, if ``level`` < 1,
            or if the selected image chunk and the label list differ in
            length (e.g. ``level`` beyond the data, or a truncated file).
    """

    test_list = "test_batch"  # file name of the original pickled test batch
    test_set_size = 10000  # number of rows per severity level

    def __init__(self, data_path, transform=None, level: int = 5, original_labels_path=None):
        if original_labels_path is None:
            raise ValueError("original_labels_path is required to load the test labels")
        if level < 1:
            raise ValueError(f"level must be >= 1, got {level}")

        start = (level - 1) * self.test_set_size
        stop = level * self.test_set_size
        # Copy the slice. A plain slice is a view that would keep the whole
        # multi-level array returned by np.load alive in memory.
        self.data = np.load(data_path)[start:stop].copy()
        self.transform = transform

        # Load the original pickled test batch to get its labels.
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only point original_labels_path at trusted dataset files.
        file_path = os.path.join(original_labels_path, self.test_list)
        with open(file_path, "rb") as f:
            entry = pickle.load(f, encoding="latin1")
        # Label key lookup mirrors torchvision's CIFAR loader: "labels" if
        # present, otherwise "fine_labels".
        self.targets = entry["labels"] if "labels" in entry else entry["fine_labels"]

        # Fail fast instead of raising IndexError halfway through iteration.
        if len(self.data) != len(self.targets):
            raise ValueError(
                f"level {level} selected {len(self.data)} images from {data_path!r} "
                f"but {len(self.targets)} labels were loaded from {file_path!r}"
            )

    def __len__(self):
        """Return the number of (image, label) pairs actually loaded."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``.

        ``img`` is a ``PIL.Image`` built from the stored array row, passed
        through ``self.transform`` when one was given.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target